{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Sample selection in NearMiss\n\nThis example illustrates the different ways of selecting samples in\n:class:`~imblearn.under_sampling.NearMiss`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>\n# License: MIT"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(__doc__)\n\nimport seaborn as sns\n\nsns.set_context(\"poster\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We define a helper function that decorates each plot: it removes the top and\nright spines, fixes the axis limits and ticks, labels the axes, and adds a\nlegend.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def make_plot_despine(ax):\n    sns.despine(ax=ax, offset=10)\n    ax.set_xlim([0, 3.5])\n    ax.set_ylim([0, 3.5])\n    ax.set_xticks(np.arange(0, 3.6, 0.5))\n    ax.set_yticks(np.arange(0, 3.6, 0.5))\n    ax.set_xlabel(r\"$X_1$\")\n    ax.set_ylabel(r\"$X_2$\")\n    ax.legend(loc=\"upper left\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We start by generating some data that we later use to illustrate the principle\nof each :class:`~imblearn.under_sampling.NearMiss` heuristic rule.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\nrng = np.random.RandomState(18)\n\nX_minority = np.transpose(\n    [[1.1, 1.3, 1.15, 0.8, 0.8, 0.6, 0.55], [1.0, 1.5, 1.7, 2.5, 2.0, 1.2, 0.55]]\n)\nX_majority = np.transpose(\n    [\n        [2.1, 2.12, 2.13, 2.14, 2.2, 2.3, 2.5, 2.45],\n        [1.5, 2.1, 2.7, 0.9, 1.0, 1.4, 2.4, 2.9],\n    ]\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## NearMiss-1\n\nNearMiss-1 selects the samples from the majority class for which the average\ndistance to their nearest neighbors in the minority class is the smallest. In\nthe following example, we use a 3-NN to compute the average distance for 2\nspecific samples of the majority class. In this case, the point linked by the\ngreen dashed lines will be selected since its average distance is smaller.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\nfrom sklearn.neighbors import NearestNeighbors\n\nfig, ax = plt.subplots(figsize=(8, 8))\nax.scatter(\n    X_minority[:, 0],\n    X_minority[:, 1],\n    label=\"Minority class\",\n    s=200,\n    marker=\"_\",\n)\nax.scatter(\n    X_majority[:, 0],\n    X_majority[:, 1],\n    label=\"Majority class\",\n    s=200,\n    marker=\"+\",\n)\n\n# for the first 2 majority samples, find the 3 nearest minority samples\nnearest_neighbors = NearestNeighbors(n_neighbors=3)\nnearest_neighbors.fit(X_minority)\ndist, ind = nearest_neighbors.kneighbors(X_majority[:2, :])\ndist_avg = dist.mean(axis=1)\n\n# link each majority sample to its neighbors; label only the first line so\n# that the legend shows a single entry per majority sample\nfor majority_idx, (neighbors, distance, color) in enumerate(\n    zip(ind, dist_avg, [\"g\", \"r\"])\n):\n    for line_idx, sample_idx in enumerate(neighbors):\n        ax.plot(\n            [X_majority[majority_idx, 0], X_minority[sample_idx, 0]],\n            [X_majority[majority_idx, 1], X_minority[sample_idx, 1]],\n            \"--\" + color,\n            alpha=0.3,\n            label=f\"Avg. dist.={distance:.2f}\" if line_idx == 0 else \"\",\n        )\nax.set_title(\"NearMiss-1\")\nmake_plot_despine(ax)\nplt.tight_layout()"
      ]
    },
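    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The same criterion can be applied to every majority sample rather than only\nthe 2 drawn above. The following is a minimal sketch, not the\n:class:`~imblearn.under_sampling.NearMiss` implementation: it reuses\n`X_minority` and `X_majority` from the cells above, and it assumes that we keep\nas many majority samples as there are minority samples. The names\n`avg_dist_nearest`, `n_keep`, and `kept_idx_nm1` are introduced here for\nillustration only.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom sklearn.neighbors import NearestNeighbors\n\n# average distance from each majority sample to its 3 nearest minority samples\nnn_minority = NearestNeighbors(n_neighbors=3).fit(X_minority)\ndist_all, _ = nn_minority.kneighbors(X_majority)\navg_dist_nearest = dist_all.mean(axis=1)\n\n# assumption: keep as many majority samples as there are minority samples\nn_keep = X_minority.shape[0]\nkept_idx_nm1 = np.argsort(avg_dist_nearest)[:n_keep]\n\nfor idx in kept_idx_nm1:\n    print(f\"majority sample {idx}: avg. dist.={avg_dist_nearest[idx]:.2f}\")"
      ]
    },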
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## NearMiss-2\n\nNearMiss-2 selects the samples from the majority class for which the average\ndistance to the farthest neighbors in the minority class is the smallest. With\nthe same configuration as before, the sample linked by the green dashed lines\nwill be selected since its average distance to the 3 farthest neighbors is the\nsmallest.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig, ax = plt.subplots(figsize=(8, 8))\nax.scatter(\n    X_minority[:, 0],\n    X_minority[:, 1],\n    label=\"Minority class\",\n    s=200,\n    marker=\"_\",\n)\nax.scatter(\n    X_majority[:, 0],\n    X_majority[:, 1],\n    label=\"Majority class\",\n    s=200,\n    marker=\"+\",\n)\n\n# query all minority samples; distances are returned in increasing order, so\n# the last 3 columns correspond to the 3 farthest minority samples\nnearest_neighbors = NearestNeighbors(n_neighbors=X_minority.shape[0])\nnearest_neighbors.fit(X_minority)\ndist, ind = nearest_neighbors.kneighbors(X_majority[:2, :])\ndist = dist[:, -3:]\nind = ind[:, -3:]\ndist_avg = dist.mean(axis=1)\n\n# link each majority sample to its neighbors; label only the first line so\n# that the legend shows a single entry per majority sample\nfor majority_idx, (neighbors, distance, color) in enumerate(\n    zip(ind, dist_avg, [\"g\", \"r\"])\n):\n    for line_idx, sample_idx in enumerate(neighbors):\n        ax.plot(\n            [X_majority[majority_idx, 0], X_minority[sample_idx, 0]],\n            [X_majority[majority_idx, 1], X_minority[sample_idx, 1]],\n            \"--\" + color,\n            alpha=0.3,\n            label=f\"Avg. dist.={distance:.2f}\" if line_idx == 0 else \"\",\n        )\nax.set_title(\"NearMiss-2\")\nmake_plot_despine(ax)\nplt.tight_layout()"
      ]
    },
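    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration, the NearMiss-2 criterion can also be evaluated for\nevery majority sample. This is a sketch under stated assumptions, not the\n:class:`~imblearn.under_sampling.NearMiss` implementation: it reuses\n`X_minority` and `X_majority` from the cells above, averages over the 3\nfarthest minority samples as in the plot, and assumes that we keep as many\nmajority samples as there are minority samples. The names `avg_dist_farthest`,\n`n_keep`, and `kept_idx_nm2` are hypothetical.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom sklearn.neighbors import NearestNeighbors\n\n# distances from each majority sample to all minority samples, sorted in\n# increasing order along each row\nnn_full = NearestNeighbors(n_neighbors=X_minority.shape[0]).fit(X_minority)\ndist_full, _ = nn_full.kneighbors(X_majority)\n\n# average over the 3 farthest minority samples (last 3 columns)\navg_dist_farthest = dist_full[:, -3:].mean(axis=1)\n\n# assumption: keep as many majority samples as there are minority samples\nn_keep = X_minority.shape[0]\nkept_idx_nm2 = np.argsort(avg_dist_farthest)[:n_keep]\n\nfor idx in kept_idx_nm2:\n    print(f\"majority sample {idx}: avg. dist.={avg_dist_farthest[idx]:.2f}\")"
      ]
    },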
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## NearMiss-3\n\nNearMiss-3 can be divided into 2 steps. First, a nearest-neighbors search is\nused to short-list samples from the majority class: for each minority sample,\nits nearest majority neighbors are kept (i.e. the highlighted samples in the\nfollowing plot). Then, the short-listed samples with the largest average\ndistance to their *k* nearest neighbors in the minority class are selected.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig, ax = plt.subplots(figsize=(8.5, 8.5))\nax.scatter(\n    X_minority[:, 0],\n    X_minority[:, 1],\n    label=\"Minority class\",\n    s=200,\n    marker=\"_\",\n)\nax.scatter(\n    X_majority[:, 0],\n    X_majority[:, 1],\n    label=\"Majority class\",\n    s=200,\n    marker=\"+\",\n)\n\nnearest_neighbors = NearestNeighbors(n_neighbors=3)\nnearest_neighbors.fit(X_majority)\n\n# step 1: short-list the majority samples that are among the 3 nearest\n# majority neighbors of at least one minority sample\n# (note that X_majority is overwritten with the short-listed samples)\nselected_idx = nearest_neighbors.kneighbors(X_minority, return_distance=False)\nX_majority = X_majority[np.unique(selected_idx), :]\nax.scatter(\n    X_majority[:, 0],\n    X_majority[:, 1],\n    label=\"Short-listed samples\",\n    s=200,\n    alpha=0.3,\n    color=\"g\",\n)\n\n# step 2: average distance from the first 2 short-listed samples to their 3\n# nearest minority samples\nnearest_neighbors = NearestNeighbors(n_neighbors=3)\nnearest_neighbors.fit(X_minority)\ndist, ind = nearest_neighbors.kneighbors(X_majority[:2, :])\ndist_avg = dist.mean(axis=1)\n\n# label only the first line so that the legend shows a single entry per sample\nfor majority_idx, (neighbors, distance, color) in enumerate(\n    zip(ind, dist_avg, [\"r\", \"g\"])\n):\n    for line_idx, sample_idx in enumerate(neighbors):\n        ax.plot(\n            [X_majority[majority_idx, 0], X_minority[sample_idx, 0]],\n            [X_majority[majority_idx, 1], X_minority[sample_idx, 1]],\n            \"--\" + color,\n            alpha=0.3,\n            label=f\"Avg. dist.={distance:.2f}\" if line_idx == 0 else \"\",\n        )\nax.set_title(\"NearMiss-3\")\nmake_plot_despine(ax)\nplt.tight_layout()\nplt.show()"
      ]
    }
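    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The second step of NearMiss-3 can be sketched for all short-listed samples.\nAt this point in the notebook, `X_majority` holds only the short-listed\nsamples, since the cell above overwrites it. This is an illustrative sketch,\nnot the :class:`~imblearn.under_sampling.NearMiss` implementation: the number\nof samples to keep (`n_keep`) is an assumption, and the names `avg_dist_short`\nand `kept_idx_nm3` are hypothetical.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom sklearn.neighbors import NearestNeighbors\n\n# average distance from each short-listed majority sample to its 3 nearest\n# minority samples\nnn_minority = NearestNeighbors(n_neighbors=3).fit(X_minority)\ndist_short, _ = nn_minority.kneighbors(X_majority)\navg_dist_short = dist_short.mean(axis=1)\n\n# assumption: keep at most as many samples as there are minority samples\nn_keep = min(X_minority.shape[0], X_majority.shape[0])\n\n# NearMiss-3 keeps the largest average distances, hence the reversed order\nkept_idx_nm3 = np.argsort(avg_dist_short)[::-1][:n_keep]\n\nfor idx in kept_idx_nm3:\n    print(f\"short-listed sample {idx}: avg. dist.={avg_dist_short[idx]:.2f}\")"
      ]
    }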
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}